#pragma once

#include <cooperative_groups.h>

#include "gn_utils.hpp"

namespace group_norm_v2 {

namespace cg = cooperative_groups;

// Ceiling division: the number of size-`b` chunks needed to cover `a`.
// Computed as (a + b - 1) / b, so it is intended for non-negative `a` and positive `b`.
template <typename T>
inline constexpr T up_div(T a, T b) {
  const auto biased_numerator = a + b - 1;
  return biased_numerator / b;
}

// Rounds `a` up to the nearest multiple of `b` (ceiling quotient from up_div, times b).
template <typename T>
inline constexpr T round_up(T a, T b) {
  const T chunk_count = up_div(a, b);
  return chunk_count * b;
}

// Returns the smallest power of two that is >= x.
//   - x == 0 and x == 1 both yield 1 (the guard avoids the `x--` wrap-around, which would end in an
//     undefined `1U << 32`).
//   - x must not exceed 0x80000000: a larger result is not representable in `unsigned`.
inline constexpr unsigned round_up_pow2(unsigned x) {
  if (x <= 1) {
    return 1U;
  }
  // Count the bits needed to represent (x - 1); 2^bits is then the first power of two >= x.
  unsigned log = 0;
  x--;
  while (x) {
    x /= 2;
    log++;
  }
  return 1U << log;
}

// Returns the largest power of two that is <= x, or 0 when x == 0.
// Implemented by locating the highest set bit directly, so every `unsigned` input is valid
// (deriving it from round_up_pow2(x + 1) / 2 would need 2^32 for x >= 0x80000000).
inline constexpr unsigned round_down_pow2(unsigned x) {
  if (x == 0) {
    return 0U;
  }
  unsigned result = 1U;
  // Each remaining bit above bit 0 doubles the result; the loop ends at the highest set bit.
  while (x >>= 1) {
    result <<= 1;
  }
  return result;
}

// Greatest common divisor via the iterative Euclidean algorithm.
// gcd(a, 0) == a. The temporary uses type T so that operands wider than `int` are not truncated.
template <typename T>
inline constexpr T gcd(T a, T b) {
  while (b != 0) {
    T t = b;
    b = a % b;
    a = t;
  }
  return a;
}

// Least common multiple of a and b.
// Divides by the gcd before multiplying so the intermediate value never exceeds the result itself
// (the product a * b can overflow T even when the lcm fits). gcd(a, b) divides `a` exactly, so the
// value matches (a * b) / gcd(a, b) whenever that product does not overflow.
template <typename T>
inline constexpr T lcm(T a, T b) {
  return a / gcd(a, b) * b;
}

// Returns the smallest p >= min such that gcd(p, x) == 1, i.e. p is coprime with x.
// The candidate is kept in type T: an `int` candidate would truncate wide values and would also make
// the gcd(p, x) call fail template argument deduction whenever T is not `int`.
template <typename T>
inline constexpr T relative_prime(T x, T min) {
  T p = min;
  while (gcd(p, x) != 1) {
    p++;
  }
  return p;
}

// Returns the largest divisor of x that is <= max, searching downwards from `max`.
// Expects max >= 1; the search stops at 1 at the latest because 1 divides every x.
// The candidate is kept in type T so that values wider than `int` are not truncated.
template <typename T>
inline constexpr T max_divisor(T x, T max) {
  T p = max;
  while (x % p != 0) {
    p--;
  }
  return p;
}

// Participation mask with all 32 bits set; passed as the mask argument of the
// __shfl_xor_sync calls in this file.
constexpr unsigned FINAL_MASK = 0xffffffff;

// Synchronizes the blocks that form one virtual cluster. Three compile-time paths:
//   - VIRTUAL_CLUSTER_SIZE == 1: a plain __syncthreads().
//   - HARDWARE_CLUSTER: cg::this_cluster().sync().
//   - otherwise: a spin barrier on the counter barrier[blockIdx.y]. This path requires PERSISTENT
//     (enforced by the static_assert, whose message is "potential deadlock").
//
// `barrier` is only dereferenced on the software path.
// NOTE(review): the software path assumes exactly VIRTUAL_CLUSTER_SIZE blocks share each blockIdx.y and
// that one of them has blockIdx.x == 0 - confirm against the launch configuration.
template <int VIRTUAL_CLUSTER_SIZE, bool PERSISTENT, bool HARDWARE_CLUSTER>
__device__ void virtual_cluster_sync(unsigned int *barrier) {
  if constexpr (VIRTUAL_CLUSTER_SIZE == 1) {
    __syncthreads();
  } else if constexpr (HARDWARE_CLUSTER) {
    cg::this_cluster().sync();
  } else {
    static_assert(PERSISTENT, "potential deadlock");
    volatile unsigned int *arrived = &barrier[blockIdx.y];
    // All threads of the block reach this point before thread 0 signals arrival.
    __syncthreads();
    if (threadIdx.x == 0) {
      unsigned int expected = VIRTUAL_CLUSTER_SIZE;
      bool gpu_master = blockIdx.x == 0;
      // The master block adds 0x80000000 - (expected - 1) and every other block adds 1, so the
      // contributions of `expected` blocks add up to exactly 0x80000000 per barrier round.
      unsigned int nb = 1;
      if (gpu_master) {
        nb = 0x80000000 - (expected - 1);
      }
      unsigned int oldArrive;
      // Atomic add with release semantics at GPU scope; oldArrive receives the counter value before the add.
      asm volatile("atom.add.release.gpu.u32 %0,[%1],%2;"
                   : "=r"(oldArrive)
                   : _CG_ASM_PTR_CONSTRAINT((unsigned int *)arrived), "r"(nb)
                   : "memory");
      unsigned int current_arrive;
      // Spin with acquire loads until bar_has_flipped reports completion.
      // NOTE(review): bar_has_flipped presumably detects a toggle of the top bit (0x80000000) between
      // the two values - confirm in the cooperative_groups details header.
      do {
        asm volatile("ld.acquire.gpu.u32 %0,[%1];"
                     : "=r"(current_arrive)
                     : _CG_ASM_PTR_CONSTRAINT((unsigned int *)arrived)
                     : "memory");
      } while (!cooperative_groups::details::bar_has_flipped(oldArrive, current_arrive));
    }
    // Hold the remaining threads of the block until thread 0 has observed the barrier completing.
    __syncthreads();
  }
}

// Arrival half of a split spin barrier on the counter barrier[0].
//
// The whole block first meets at __syncthreads(); thread 0 then atomically adds its contribution
// (release semantics, GPU scope) and returns the counter value observed before the add. All other
// threads return 0. The block flagged `gpu_master` adds 0x80000000 - (NUM_BLOCKS - 1) while every
// other block adds 1, so NUM_BLOCKS arrivals add up to 0x80000000.
//
// The returned value is meant to be handed to group_barrier_wait(). Requires PERSISTENT
// (static_assert, "potential deadlock").
template <int NUM_BLOCKS, bool PERSISTENT>
__device__ unsigned int group_barrier_arrive(unsigned int *barrier, bool gpu_master) {
  static_assert(PERSISTENT, "potential deadlock");
  __syncthreads();
  if (threadIdx.x != 0) {
    return 0;
  }
  volatile unsigned int *counter = &barrier[0];
  const unsigned int increment = gpu_master ? 0x80000000 - (static_cast<unsigned int>(NUM_BLOCKS) - 1) : 1;
  unsigned int previous;
  asm volatile("atom.add.release.gpu.u32 %0,[%1],%2;"
               : "=r"(previous)
               : _CG_ASM_PTR_CONSTRAINT((unsigned int *)counter), "r"(increment)
               : "memory");
  return previous;
}

// Waiting half of the split spin barrier on barrier[0].
//
// Thread 0 polls the counter with acquire loads (GPU scope) until
// cooperative_groups::details::bar_has_flipped(oldArrive, <current value>) is true; the trailing
// __syncthreads() then releases the remaining threads of the block.
// `oldArrive` is the value thread 0 obtained from group_barrier_arrive(); it is ignored by other threads.
__device__ inline void group_barrier_wait(unsigned int *barrier, unsigned int oldArrive) {
  if (threadIdx.x == 0) {
    volatile unsigned int *counter = &barrier[0];
    unsigned int observed;
    for (;;) {
      asm volatile("ld.acquire.gpu.u32 %0,[%1];"
                   : "=r"(observed)
                   : _CG_ASM_PTR_CONSTRAINT((unsigned int *)counter)
                   : "memory");
      if (cooperative_groups::details::bar_has_flipped(oldArrive, observed)) {
        break;
      }
    }
  }
  __syncthreads();
}

// Calculate `n` (batch id) and `c` (channel range id) for each loop
//   CONSTANT_C_LOOP      selects one of the two specializations below
//   C / C_PER_CLUSTER    number of channel ranges per batch item
//   NUM_VIRTUAL_CLUSTERS stride used when advancing to the next work item (PERSISTENT only)
//   PERSISTENT           when false, next() does nothing and at_end() is always true
template <bool CONSTANT_C_LOOP, int C, int C_PER_CLUSTER, int NUM_VIRTUAL_CLUSTERS, bool PERSISTENT>
class NCScheduler;

// Scheduler for CONSTANT_C_LOOP == false.
// Tracks a single flattened work-item index nc = n_idx * (C / C_PER_CLUSTER) + c_idx, starting at
// blockIdx.y and advancing by NUM_VIRTUAL_CLUSTERS per next() call, so both the batch id and the
// channel range id may change from one loop to the next.
template <int C, int C_PER_CLUSTER, int NUM_VIRTUAL_CLUSTERS, bool PERSISTENT>
class NCScheduler<false, C, C_PER_CLUSTER, NUM_VIRTUAL_CLUSTERS, PERSISTENT> {
 public:
  // `n` is the batch size; the work list holds n * (C / C_PER_CLUSTER) items.
  __device__ NCScheduler(int64_t n)
      : nc_loop_(blockIdx.y), at_end_(nc_loop_ >= n * (C / C_PER_CLUSTER)) {}

  // Splits the flattened index into (batch id as int64_t, channel range id as int).
  __device__ auto get_nc() {
    constexpr int kChannelRanges = C / C_PER_CLUSTER;
    const int64_t batch_idx = nc_loop_ / kChannelRanges;
    const int range_idx = nc_loop_ % kChannelRanges;
    return std::make_tuple(batch_idx, range_idx);
  }

  // Moves on to the next work item and refreshes the cached end flag; a no-op unless PERSISTENT.
  __device__ void next(int64_t n) {
    if constexpr (PERSISTENT) {
      constexpr int kChannelRanges = C / C_PER_CLUSTER;
      nc_loop_ += NUM_VIRTUAL_CLUSTERS;
      at_end_ = nc_loop_ >= n * kChannelRanges;
    }
  }

  // True once the work list is exhausted; always true for non-persistent kernels.
  __device__ bool at_end(int64_t n) {
    if constexpr (!PERSISTENT) {
      return true;
    } else {
      return at_end_;
    }
  }

 private:
  int64_t nc_loop_;  // flattened (n, c) work-item index
  bool at_end_;      // cached result of nc_loop_ >= n * (C / C_PER_CLUSTER)
};

// Scheduler for CONSTANT_C_LOOP == true.
// The channel range id is derived once from blockIdx.y and never changes; only the batch id
// advances, by NUM_VIRTUAL_CLUSTERS / (C / C_PER_CLUSTER) per next() call.
template <int C, int C_PER_CLUSTER, int NUM_VIRTUAL_CLUSTERS, bool PERSISTENT>
class NCScheduler<true, C, C_PER_CLUSTER, NUM_VIRTUAL_CLUSTERS, PERSISTENT> {
 public:
  // The batch size `n` is not needed to compute the starting position.
  __device__ NCScheduler(int64_t n)
      : n_loop_(blockIdx.y / (C / C_PER_CLUSTER)), c_loop_(blockIdx.y % (C / C_PER_CLUSTER)) {}

  // Returns (batch id as int64_t, channel range id as int).
  __device__ auto get_nc() {
    return std::make_tuple(n_loop_, c_loop_);
  }

  // Advances the batch id; a no-op unless PERSISTENT.
  __device__ void next(int64_t n) {
    if constexpr (PERSISTENT) {
      constexpr int kBatchStride = NUM_VIRTUAL_CLUSTERS / (C / C_PER_CLUSTER);
      n_loop_ += kBatchStride;
    }
  }

  // True once the batch id has moved past the last batch item; always true for non-persistent kernels.
  __device__ bool at_end(int64_t n) {
    if constexpr (!PERSISTENT) {
      return true;
    } else {
      return n_loop_ >= n;
    }
  }

 private:
  int64_t n_loop_;  // current batch id
  int c_loop_;      // channel range id, fixed for the lifetime of the scheduler
};

// Default `CompileCondition` policy of the kernels in this file: matches() is always true, so the
// kernel body guarded by `if constexpr (CompileCondition::matches())` is always compiled.
class CompileConditionAlwaysTrue {
 public:
  static __device__ constexpr bool matches() {
    return true;
  }
};

// GroupNorm forward kernel, optionally followed by SiLU (when SILU is true).
//
// Data layout (from the indexing below): x and out are addressed as
//   [n_loop * HW * C + row * C + channel], i.e. channels are contiguous; w and b are indexed by channel.
//
// Template parameters:
//   T                 element type of out / x / w / b; all arithmetic is done in float
//   BLOCK_DIM_X       threads per block (multiple of 32); together with BLOCKS_PER_SM feeds __launch_bounds__
//   G, CPG            number of groups and channels per group (C = G * CPG)
//   HW                rows per batch item
//   ROWS_PER_BLOCK,
//   C_PER_BLOCK       tile (rows x channels) processed by one block per loop
//   C_PER_CLUSTER     channels covered by one virtual cluster (must be a multiple of CPG)
//   VEC_ELEMS         elements of T per vectorized load/store
//   PERSISTENT        loop over several (n, c) work items per block instead of exactly one
//   NUM_VIRTUAL_CLUSTERS  work-item stride of the persistent loop (see NCScheduler)
//   LOAD_TWICE        re-read x from global memory in the scale phase instead of caching it in `frag`
//   HARDWARE_CLUSTER  use cg::this_cluster() (sync + distributed shared memory) instead of red_buffer/barrier
//   CompileCondition  the body is compiled only if CompileCondition::matches() is true
//
// Runtime parameters:
//   eps           added to the variance before rsqrtf
//   n             batch size
//   mean_var_out  optional (may be null); receives a float2 {mean, var} at index (n_loop * G + g) * 2
//   red_buffer    global scratch for partial sums; only accessed when neither shared nor distributed
//                 shared memory is used for the reduction
//   barrier       counters forwarded to virtual_cluster_sync
//
// NOTE(review): blockIdx.x % VIRTUAL_CLUSTER_SIZE is used as the position inside a virtual cluster and
// blockIdx.y as the virtual cluster id - confirm the grid shape against the host-side launcher.
template <typename T, int BLOCK_DIM_X, int BLOCKS_PER_SM, int G, int CPG, int HW, bool SILU, int ROWS_PER_BLOCK,
          int C_PER_BLOCK, int C_PER_CLUSTER, int VEC_ELEMS, bool PERSISTENT, int NUM_VIRTUAL_CLUSTERS, bool LOAD_TWICE,
          bool HARDWARE_CLUSTER, class CompileCondition = CompileConditionAlwaysTrue>
__global__ __launch_bounds__(BLOCK_DIM_X, BLOCKS_PER_SM) void gn_cuda_kernel(
    T *__restrict__ out, T const *__restrict__ x, T const *__restrict__ w, T const *__restrict__ b, float eps,
    int64_t n, float *__restrict__ mean_var_out, float *__restrict__ red_buffer, unsigned *__restrict__ barrier) {
  // Procedure Overview
  //   1. Thread sum: read from gmem, write partial sum to smem, store input in registers (if no LOAD_TWICE)
  //   2. Block sum: read from smem, write partial sum to gmem (or distributed shared memory if HARDWARE_CLUSTER is
  //   used)
  //   3. Group sum: read from gmem, write mean&var to smem
  //   4. Scale: read mean&var from smem, read input from gmem (if LOAD_TWICE), write output to gmem

  static_assert(BLOCK_DIM_X % 32 == 0, "warp shuffle error");

  constexpr int C = G * CPG;
  static_assert(C % C_PER_CLUSTER == 0, "cannot divide channels into clusters");
  static_assert(C_PER_CLUSTER % C_PER_BLOCK == 0, "cannot divide a cluster into blocks");
  static_assert(C_PER_CLUSTER % CPG == 0, "no reduce between clusters, would produce incorrect results");
  static_assert(!(C_PER_BLOCK % CPG == 0 && C_PER_CLUSTER != C_PER_BLOCK),
                "inefficient configuration, please reduce C_PER_CLUSTER");

  static_assert(ROWS_PER_BLOCK * C_PER_BLOCK % BLOCK_DIM_X == 0, "cannot divide tile into threads");
  // Vector of VEC_ELEMS elements, aligned to its own size so it can be loaded/stored in one access.
  struct alignas(VEC_ELEMS * sizeof(T)) U {
    T data[VEC_ELEMS];
  };

  // Converts an accumulated {sum, sum of squares} over HW * CPG elements into {mean, variance};
  // the variance is clamped to be non-negative.
  auto compute_mean_var = [&](float2 sum) {
    float mean = sum.x / (HW * CPG);
    float var = std::max(0.f, sum.y / (HW * CPG) - mean * mean);
    return float2{mean, var};
  };

  static_assert(HW % ROWS_PER_BLOCK == 0,
                "HW must be divisible by ROWS_PER_BLOCK to determine the number of blocks on the HW axis");
  // Number of group slots reserved per block: exact when the block's channels align with group
  // boundaries, otherwise an upper bound for a channel range that straddles groups.
  constexpr int MAX_NUM_GROUPS_PER_BLOCK =
      C_PER_BLOCK % CPG == 0 ? C_PER_BLOCK / CPG : up_div(C_PER_BLOCK - gcd(C_PER_BLOCK, CPG), CPG) + 1;
  constexpr int VIRTUAL_CLUSTER_SIZE = (C_PER_CLUSTER / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK);
  constexpr int virtual_cluster_dim_x = C_PER_CLUSTER / C_PER_BLOCK;
  constexpr int virtual_cluster_dim_y = HW / ROWS_PER_BLOCK;
  // Position of this block inside its virtual cluster: x selects the channel slice, y the row slice.
  int virtual_block_idx_x = (blockIdx.x % VIRTUAL_CLUSTER_SIZE) % virtual_cluster_dim_x;
  int virtual_block_idx_y = (blockIdx.x % VIRTUAL_CLUSTER_SIZE) / virtual_cluster_dim_x;

  if constexpr (CompileCondition::matches()) {
    // `step` alternates between 0 and 1 per loop iteration and selects which half of the
    // reduction buffers (shared_red_buffer / red_buffer) is used.
    int step = 0;
    constexpr bool CONSTANT_C_LOOP = PERSISTENT && NUM_VIRTUAL_CLUSTERS % (C / C_PER_CLUSTER) == 0;
    NCScheduler<CONSTANT_C_LOOP, C, C_PER_CLUSTER, NUM_VIRTUAL_CLUSTERS, PERSISTENT> nc_scheduler(n);
    while (true) {  // TODO: unroll the loop
      if constexpr (PERSISTENT) {
        if (nc_scheduler.at_end(n)) {
          break;
        }
      }
      auto [n_loop, c_loop] = nc_scheduler.get_nc();
      // Advance the scheduler right away so that at_end() below already refers to the NEXT iteration.
      if constexpr (PERSISTENT) {
        nc_scheduler.next(n);
      }
      static_assert(C_PER_BLOCK % VEC_ELEMS == 0, "cannot vectorize");
      static_assert((BLOCK_DIM_X * VEC_ELEMS) % C_PER_BLOCK == 0,
                    "each block should load one or more C_PER_BLOCK at once");
      // Rows covered by one vectorized load of the whole block.
      constexpr int ROWS_PER_IO = BLOCK_DIM_X * VEC_ELEMS / C_PER_BLOCK;
      static_assert(ROWS_PER_BLOCK % ROWS_PER_IO == 0, "cannot determine the IO times per batch");
      int block_channel_start = virtual_block_idx_x * C_PER_BLOCK + c_loop * C_PER_CLUSTER;
      int block_group_start = block_channel_start / CPG;
      int thread_channel_start = block_channel_start + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS) * VEC_ELEMS;
      // Register cache of this thread's input vectors (only filled when !LOAD_TWICE).
      U frag[ROWS_PER_BLOCK / ROWS_PER_IO];

      // GCD_VEC_CPG is an important constant that determines how many channels can be merged in reduction computation
      //   For example, VEC_ELEMS=4 and CPG=10, then GCD_VEC_CPG=2,
      //   so we need to store only 2 sums on each thread, and compute only 2 mean&var for each thread.
      constexpr int GCD_VEC_CPG = gcd(VEC_ELEMS, CPG);

      // If each block handles only one group, run warpReduce and store the sum to `sum_per_channel_single_group`;
      // otherwise store (VEC_ELEMS / GCD_VEC_CPG) sums to `sum_per_channel_multi_group`, where `relative_prime` is used
      // for swizzle.
      constexpr bool SINGLE_GROUP_PER_BLOCK = CPG % C_PER_BLOCK == 0;
      [[maybe_unused]] __shared__ float2 sum_per_channel_single_group[BLOCK_DIM_X / 32];
      [[maybe_unused]] __shared__ float2 sum_per_channel_multi_group[C_PER_BLOCK / GCD_VEC_CPG][relative_prime(
          128 / (int)sizeof(float2), ROWS_PER_IO)];

      // Thread sum
      if constexpr (LOAD_TWICE) {
        // Accumulate {sum, sum of squares} while streaming the input; the values are not kept in registers.
        float2 frag_sum_per_channel[VEC_ELEMS / GCD_VEC_CPG]{};
        for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) {
          int64_t input_idx =
              n_loop * HW * C +
              (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C +
              thread_channel_start;
          U val = *reinterpret_cast<U const *>(&x[input_idx]);
          for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) {
            float2 sum = frag_sum_per_channel[i];
            for (int k = 0; k < GCD_VEC_CPG; k++) {
              sum.x += (float)val.data[i * GCD_VEC_CPG + k];
              sum.y += (float)val.data[i * GCD_VEC_CPG + k] * (float)val.data[i * GCD_VEC_CPG + k];
            }
            frag_sum_per_channel[i] = sum;
          }
        }
        for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) {
          if constexpr (SINGLE_GROUP_PER_BLOCK) {
            // Butterfly reduction across the 32 lanes of the warp; lane 0 publishes the warp total.
            for (int mask = 16; mask > 0; mask >>= 1) {
              frag_sum_per_channel[i].x += __shfl_xor_sync(FINAL_MASK, frag_sum_per_channel[i].x, mask, 32);
              frag_sum_per_channel[i].y += __shfl_xor_sync(FINAL_MASK, frag_sum_per_channel[i].y, mask, 32);
            }
            static_assert(VEC_ELEMS / GCD_VEC_CPG == 1, "process only one element for each warp");
            if (threadIdx.x % 32 == 0) {
              sum_per_channel_single_group[threadIdx.x / 32] = frag_sum_per_channel[i];
            }
          } else {
            sum_per_channel_multi_group[i * (C_PER_BLOCK / VEC_ELEMS) + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS)]
                                       [threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)] = frag_sum_per_channel[i];
          }
        }
        __syncthreads();
      } else {
        // Load the whole tile into registers first, then accumulate from the cached values.
        for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) {
          int64_t input_idx =
              n_loop * HW * C +
              (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C +
              thread_channel_start;
          frag[j] = *reinterpret_cast<U const *>(&x[input_idx]);
        }

        for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) {
          float2 sum = {0.f, 0.f};
          for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) {
            for (int k = 0; k < GCD_VEC_CPG; k++) {
              sum.x += (float)frag[j].data[i * GCD_VEC_CPG + k];
              sum.y += (float)frag[j].data[i * GCD_VEC_CPG + k] * (float)frag[j].data[i * GCD_VEC_CPG + k];
            }
          }
          if constexpr (SINGLE_GROUP_PER_BLOCK) {
            for (int mask = 16; mask > 0; mask >>= 1) {
              sum.x += __shfl_xor_sync(FINAL_MASK, sum.x, mask, 32);
              sum.y += __shfl_xor_sync(FINAL_MASK, sum.y, mask, 32);
            }
            static_assert(VEC_ELEMS / GCD_VEC_CPG == 1, "process only one element for each warp");
            if (threadIdx.x % 32 == 0) {
              sum_per_channel_single_group[threadIdx.x / 32] = sum;
            }
          } else {
            sum_per_channel_multi_group[i * (C_PER_BLOCK / VEC_ELEMS) + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS)]
                                       [threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)] = sum;
          }
        }
        __syncthreads();
      }

      // This thread's slice of weight and bias, consumed in the scale phase.
      U uw = *reinterpret_cast<U const *>(&w[thread_channel_start]);
      U ub = *reinterpret_cast<U const *>(&b[thread_channel_start]);

      // Three cases for the red_buffer:
      //   - Block sync (VIRTUAL_CLUSTER_SIZE=1): use shared memory
      //   - Virtual cluster sync with HARDWARE_CLUSTER: use distributed shared memory
      //   - Virtual cluster sync without HARDWARE_CLUSTER: use global memory, i.e., `red_buffer`
      constexpr bool USE_SHARED_RED_BUFFER = HARDWARE_CLUSTER || VIRTUAL_CLUSTER_SIZE == 1;

      // Specialize for the case that each group is handled by only one block
      //   For common cases, blockSum produces partial sum and stores it to the red_buffer, and groupSum produces
      //   mean&var For the special case, blockSum produces mean&var directly
      constexpr bool STORE_MEAN_VAR_IN_SHARED_RED_BUFFER =
          VIRTUAL_CLUSTER_SIZE == 1 &&
          MAX_NUM_GROUPS_PER_BLOCK == 1;  // MAX_NUM_GROUPS_PER_BLOCK > 1 is possible but not implemented

      // Holds two `step` halves of partial sums, or a single set of mean&var in the special case.
      [[maybe_unused]] __align__(16)
          __shared__ float2 shared_red_buffer[MAX_NUM_GROUPS_PER_BLOCK * (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER ? 1 : 2)];

      // Block sum
      if constexpr (SINGLE_GROUP_PER_BLOCK) {
        // block reduce
        if (threadIdx.x < 32) {
          float2 sum_local_group =
              threadIdx.x < BLOCK_DIM_X / 32 ? sum_per_channel_single_group[threadIdx.x] : float2{0.f, 0.f};
          constexpr int warp_num_pow2 = round_up_pow2(BLOCK_DIM_X / 32);
          for (int mask = warp_num_pow2 / 2; mask > 0; mask >>= 1) {
            sum_local_group.x += __shfl_xor_sync(FINAL_MASK, sum_local_group.x, mask, 32);
            sum_local_group.y += __shfl_xor_sync(FINAL_MASK, sum_local_group.y, mask, 32);
          }
          if (threadIdx.x == 0) {
            if constexpr (USE_SHARED_RED_BUFFER) {
              if constexpr (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) {
                shared_red_buffer[0] = compute_mean_var(sum_local_group);
              } else {
                shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + 0] = sum_local_group;
              }
            } else {
              *reinterpret_cast<float2 *>(
                  &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK +
                               virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK +
                               // (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y +
                               virtual_block_idx_y) *
                              2]) = sum_local_group;
            }
          }
        }
      } else {
        // The number of threads to calculate the sum of each group (should be a power of 2 for warp reduce)
        constexpr int THREADS_PER_GROUP = std::min(std::min(32U, round_up_pow2(ROWS_PER_IO)),
                                                   round_up_pow2(BLOCK_DIM_X / MAX_NUM_GROUPS_PER_BLOCK / 2 + 1));
        static_assert(BLOCK_DIM_X >= MAX_NUM_GROUPS_PER_BLOCK * THREADS_PER_GROUP, "not enough threads");
        float2 sum_local_group = {0.f, 0.f};
        if (threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) {
          int local_group_idx = block_group_start + threadIdx.x / THREADS_PER_GROUP;
          // TODO: map threads to both the CPG loop and the ROWS loop
          for (int local_c_loop = 0; local_c_loop < CPG; local_c_loop += GCD_VEC_CPG) {
            int c = local_group_idx * CPG + local_c_loop;
            // Skip channels of this group that fall outside the block's channel range.
            if (C_PER_BLOCK % CPG == 0 || (c >= block_channel_start && c < block_channel_start + C_PER_BLOCK)) {
              for (int src_thread_tile_y = threadIdx.x % THREADS_PER_GROUP; src_thread_tile_y < ROWS_PER_IO;
                   src_thread_tile_y += THREADS_PER_GROUP) {
                // Recompute the first index used when the thread sums were written to shared memory.
                int channel_idx = (c - block_channel_start) / GCD_VEC_CPG;
                channel_idx = channel_idx % (VEC_ELEMS / GCD_VEC_CPG) * (C_PER_BLOCK / VEC_ELEMS) +
                              channel_idx / (VEC_ELEMS / GCD_VEC_CPG);
                sum_local_group.x += sum_per_channel_multi_group[channel_idx][src_thread_tile_y].x;
                sum_local_group.y += sum_per_channel_multi_group[channel_idx][src_thread_tile_y].y;
              }
            }
          }
        }
        static_assert(32 % THREADS_PER_GROUP == 0, "cannot shuffle");
        for (int mask = THREADS_PER_GROUP / 2; mask > 0; mask >>= 1) {
          sum_local_group.x += __shfl_xor_sync(FINAL_MASK, sum_local_group.x, mask, 32);
          sum_local_group.y += __shfl_xor_sync(FINAL_MASK, sum_local_group.y, mask, 32);
        }
        if (threadIdx.x % THREADS_PER_GROUP == 0 && threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) {
          if constexpr (USE_SHARED_RED_BUFFER) {
            static_assert(HARDWARE_CLUSTER || VIRTUAL_CLUSTER_SIZE == 1, "no distributed shared memory");
            if constexpr (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) {
              shared_red_buffer[threadIdx.x / THREADS_PER_GROUP] = compute_mean_var(sum_local_group);
            } else {
              shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP] = sum_local_group;
            }
          } else {
            *reinterpret_cast<float2 *>(
                &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK +
                             virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK +
                             (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + virtual_block_idx_y) *
                            2]) = sum_local_group;
          }
        }
      }

      // Make every block's partial sums visible to the rest of the virtual cluster.
      virtual_cluster_sync<VIRTUAL_CLUSTER_SIZE, PERSISTENT, HARDWARE_CLUSTER>(barrier);

      // Group sum
      __shared__ float2 mean_var[MAX_NUM_GROUPS_PER_BLOCK];
      if constexpr (!STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) {
        // The number of threads to calculate the sum of each group (should be a power of 2 for warp reduce)
        constexpr int THREADS_PER_GROUP = std::min(std::min(32U, round_up_pow2(virtual_cluster_dim_y)),
                                                   round_up_pow2(BLOCK_DIM_X / MAX_NUM_GROUPS_PER_BLOCK / 2 + 1));
        static_assert(BLOCK_DIM_X >= MAX_NUM_GROUPS_PER_BLOCK * THREADS_PER_GROUP, "not enough threads");
        float2 sum_global_group = {0.f, 0.f};
        if (threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) {
          if constexpr (C_PER_BLOCK % CPG == 0) {
            // Special case: no cross-virtual_cluster_dim_x reduction
            // Issue all loads first (into `buffer`), then accumulate in a second pass.
            float2 buffer[up_div(virtual_cluster_dim_y, THREADS_PER_GROUP)];
            for (int i = threadIdx.x % THREADS_PER_GROUP; i < virtual_cluster_dim_y; i += THREADS_PER_GROUP) {
              float2 val;
              if constexpr (USE_SHARED_RED_BUFFER) {
                if constexpr (VIRTUAL_CLUSTER_SIZE == 1) {
                  val = shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP];
                } else {
                  static_assert(HARDWARE_CLUSTER, "no distributed shared memory");
                  float2 const *src_shared_red_buffer = cg::this_cluster().map_shared_rank(
                      shared_red_buffer, i * virtual_cluster_dim_x + virtual_block_idx_x);
                  val = src_shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP];
                }
              } else {
                val = *reinterpret_cast<float2 const *>(
                    &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK +
                                 virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK +
                                 (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + i) *
                                2]);
              }
              buffer[i / THREADS_PER_GROUP] = val;
            }
            for (int i = threadIdx.x % THREADS_PER_GROUP; i < virtual_cluster_dim_y; i += THREADS_PER_GROUP) {
              float2 val = buffer[i / THREADS_PER_GROUP];
              sum_global_group.x += val.x;
              sum_global_group.y += val.y;
            }
          } else {
            // Common case: cross-virtual_cluster_dim_x reduction
            int local_group_idx = block_group_start + threadIdx.x / THREADS_PER_GROUP;
            for (int i = threadIdx.x % THREADS_PER_GROUP; i < VIRTUAL_CLUSTER_SIZE; i += THREADS_PER_GROUP) {
              int src_virtual_block_idx_x = i % virtual_cluster_dim_x;
              int src_block_channel_start = src_virtual_block_idx_x * C_PER_BLOCK + c_loop * C_PER_CLUSTER;
              int src_block_group_start = src_block_channel_start / CPG;
              // Slot of this group inside source block i; only in-range slots contribute.
              int relative_group_idx = local_group_idx - src_block_group_start;
              if (0 <= relative_group_idx && relative_group_idx < MAX_NUM_GROUPS_PER_BLOCK) {
                float2 val;
                if constexpr (USE_SHARED_RED_BUFFER) {
                  static_assert(HARDWARE_CLUSTER, "no distributed shared memory");
                  static_assert(VIRTUAL_CLUSTER_SIZE != 1,
                                "layout error: should not add (step * MAX_NUM_GROUPS_PER_BLOCK)");
                  float2 const *src_shared_red_buffer = cg::this_cluster().map_shared_rank(shared_red_buffer, i);
                  val = src_shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + relative_group_idx];
                } else {
                  val = *reinterpret_cast<float2 const *>(
                      &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK +
                                   src_virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK +
                                   relative_group_idx * virtual_cluster_dim_y + i / virtual_cluster_dim_x) *
                                  2]);
                }
                sum_global_group.x += val.x;
                sum_global_group.y += val.y;
              }
            }
          }
        }
        if constexpr (USE_SHARED_RED_BUFFER && VIRTUAL_CLUSTER_SIZE > 1) {
          // Need cluster sync after distributed shared memory access, otherwise behavior is undefined
          // In the persistent case the arrive is issued only when the scheduler (already advanced)
          // reports the end, i.e. on the final iteration; the matching wait is after the scale phase.
          if constexpr (PERSISTENT) {
            if (nc_scheduler.at_end(n)) {
              cg::this_cluster().barrier_arrive();
            }
          } else {
            cg::this_cluster().barrier_arrive();
          }
        }
        static_assert(32 % THREADS_PER_GROUP == 0, "cannot shuffle");
        for (int mask = THREADS_PER_GROUP / 2; mask > 0; mask >>= 1) {
          sum_global_group.x += __shfl_xor_sync(FINAL_MASK, sum_global_group.x, mask, 32);
          sum_global_group.y += __shfl_xor_sync(FINAL_MASK, sum_global_group.y, mask, 32);
        }
        if (threadIdx.x % THREADS_PER_GROUP == 0 && threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) {
          mean_var[threadIdx.x / THREADS_PER_GROUP] = compute_mean_var(sum_global_group);
        }
        __syncthreads();
      }

      // Reads {mean, var} of a block-relative group from whichever buffer holds it in this configuration.
      auto get_mean_var = [&](int relative_group_idx) {
        return STORE_MEAN_VAR_IN_SHARED_RED_BUFFER ? shared_red_buffer[relative_group_idx]
                                                   : mean_var[relative_group_idx];
      };

      // Optionally export mean&var: only blocks on row slice 0 write, one thread per group slot.
      if (mean_var_out) {
        static_assert(MAX_NUM_GROUPS_PER_BLOCK <= BLOCK_DIM_X, "need loop");
        if (virtual_block_idx_y == 0 && threadIdx.x < MAX_NUM_GROUPS_PER_BLOCK) {
          int g = block_group_start + threadIdx.x;
          if (C_PER_BLOCK % CPG == 0 || g < G) {
            *reinterpret_cast<float2 *>(&mean_var_out[(n_loop * G + g) * 2]) = get_mean_var(threadIdx.x);
          }
        }
      }

      // Scale
      // One mean/var pair per run of GCD_VEC_CPG channels of this thread's vector.
      float frag_mean[VEC_ELEMS / GCD_VEC_CPG];
      float frag_var[VEC_ELEMS / GCD_VEC_CPG];
      for (int k = 0; k < VEC_ELEMS; k += GCD_VEC_CPG) {
        frag_mean[k / GCD_VEC_CPG] = get_mean_var((thread_channel_start + k) / CPG - block_group_start).x;
        frag_var[k / GCD_VEC_CPG] = get_mean_var((thread_channel_start + k) / CPG - block_group_start).y;
      }

      for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) {
        int64_t input_idx =
            n_loop * HW * C +
            (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C +
            thread_channel_start;
        U val;
        if constexpr (LOAD_TWICE) {
          val = *reinterpret_cast<U const *>(&x[input_idx]);
        } else {
          val = frag[j];
        }
        for (int k = 0; k < VEC_ELEMS; k++) {
          // (x - mean) * rsqrt(var + eps) * weight + bias
          float f = ((float)val.data[k] - frag_mean[k / GCD_VEC_CPG]) * rsqrtf(frag_var[k / GCD_VEC_CPG] + eps) *
                        (float)uw.data[k] +
                    (float)ub.data[k];
          // SiLU: f * sigmoid(f), written as f / (1 + exp(-f)).
          if constexpr (SILU) f = f / (1.f + expf(-f));
          val.data[k] = f;
        }
        *reinterpret_cast<U *>(&out[input_idx]) = val;
      }

      // Complete the cluster barrier whose arrive was issued after the distributed shared memory reads.
      if constexpr (!STORE_MEAN_VAR_IN_SHARED_RED_BUFFER && USE_SHARED_RED_BUFFER && VIRTUAL_CLUSTER_SIZE > 1) {
        if constexpr (PERSISTENT) {
          if (nc_scheduler.at_end(n)) {
            cg::this_cluster().barrier_wait();
          }
        } else {
          cg::this_cluster().barrier_wait();
        }
      }

      if constexpr (!PERSISTENT) {
        break;
      }
      step ^= 1;
    }
  }
}

// Synchronization strategies selectable through the `wgrad_sync_method` template parameter of
// gn_bwd_cuda_kernel.
// NOTE(review): presumably this controls the synchronization that precedes the backward kernel's
// "Wgrad sum" step - confirm against the kernel body.
enum WgradSyncMethod {
  WGRAD_ARRIVE_AND_WAIT_GRID = 0,  // grid arrive after the last virtual cluster sync
  WGRAD_ARRIVE_AND_WAIT_GROUP,     // group arrive after the last virtual cluster sync (a group sync means synchronizing
                                   // all clusters cooperating on the same groups)
  WGRAD_REUSE_SUM_SYNC_GRID,       // grid sync together with the last virtual cluster sync
  WGRAD_REUSE_SUM_SYNC_GROUP,      // group sync together with the last virtual cluster sync
  WGRAD_SYNC_AT_LAST,              // add a sync at the end of NC loops
  WGRAD_SYNC_UNSPECIFIED,          // no method chosen (NOTE(review): handling not visible here - confirm)
};

template <typename T, int BLOCK_DIM_X, int BLOCKS_PER_SM, int G, int CPG, int HW, bool SILU, bool REQUIRES_WGRAD,
          int ROWS_PER_BLOCK, int C_PER_BLOCK, int C_PER_CLUSTER, int VEC_ELEMS, bool PERSISTENT,
          int NUM_VIRTUAL_CLUSTERS, bool LOAD_TWICE, bool HARDWARE_CLUSTER, WgradSyncMethod wgrad_sync_method,
          class CompileCondition = CompileConditionAlwaysTrue>
__global__ __launch_bounds__(BLOCK_DIM_X, BLOCKS_PER_SM) void gn_bwd_cuda_kernel(
    T *__restrict__ grad_input, T *__restrict__ grad_weight, T *__restrict__ grad_bias,
    T const *__restrict__ grad_output, T const *__restrict__ x, T const *__restrict__ w, T const *__restrict__ b,
    float const *__restrict__ mean_var, float eps, int64_t n, float *__restrict__ red_buffer,
    unsigned *__restrict__ barrier) {
  // Procedure Overview
  //   1. Thread sum: read from gmem, write partial sum to smem, store input in registers (if no LOAD_TWICE)
  //   2. Block sum: read from smem, write partial sum to gmem (or distributed shared memory if HARDWARE_CLUSTER is
  //   used),
  //        write wgrad to gmem at the last loop (at each loop if not CONSTANT_C_LOOP)
  //   3. Group sum: read from gmem, write mean&var to smem
  //   4. Scale: read mean&var from smem, read input from gmem (if LOAD_TWICE), write output to gmem
  //   5. Wgrad sum: read from gmem, write to gmem

  static_assert(BLOCK_DIM_X % 32 == 0, "warp shuffle error");

  constexpr int C = G * CPG;
  static_assert(C % C_PER_CLUSTER == 0, "cannot divide channels into clusters");
  static_assert(C_PER_CLUSTER % C_PER_BLOCK == 0, "cannot divide a cluster into blocks");
  static_assert(C_PER_CLUSTER % CPG == 0, "no reduce between clusters, would produce incorrect results");
  static_assert(!(C_PER_BLOCK % CPG == 0 && C_PER_CLUSTER != C_PER_BLOCK),
                "inefficient configuration, please reduce C_PER_CLUSTER");

  static_assert(ROWS_PER_BLOCK * C_PER_BLOCK % BLOCK_DIM_X == 0, "cannot divide tile into threads");
  struct alignas(VEC_ELEMS * sizeof(T)) U {
    T data[VEC_ELEMS];
  };

  // This function computes mean_dyw and mean_xdyw.
  // The function name is not changed because it has the same logic as the forward pass.
  auto compute_mean_var = [&](float2 sum) {
    float mean_dyw = sum.x / (HW * CPG);
    float mean_xdyw = sum.y / (HW * CPG);
    return float2{mean_dyw, mean_xdyw};
  };

  static_assert(HW % ROWS_PER_BLOCK == 0,
                "HW must be divisible by ROWS_PER_BLOCK to determine the number of blocks on the HW axis");
  constexpr int MAX_NUM_GROUPS_PER_BLOCK =
      C_PER_BLOCK % CPG == 0 ? C_PER_BLOCK / CPG : up_div(C_PER_BLOCK - gcd(C_PER_BLOCK, CPG), CPG) + 1;
  constexpr int VIRTUAL_CLUSTER_SIZE = (C_PER_CLUSTER / C_PER_BLOCK) * (HW / ROWS_PER_BLOCK);
  constexpr int virtual_cluster_dim_x = C_PER_CLUSTER / C_PER_BLOCK;
  constexpr int virtual_cluster_dim_y = HW / ROWS_PER_BLOCK;
  int virtual_block_idx_x = (blockIdx.x % VIRTUAL_CLUSTER_SIZE) % virtual_cluster_dim_x;
  int virtual_block_idx_y = (blockIdx.x % VIRTUAL_CLUSTER_SIZE) / virtual_cluster_dim_x;

  if constexpr (CompileCondition::matches()) {
    int step = 0;
    constexpr bool CONSTANT_C_LOOP = PERSISTENT && NUM_VIRTUAL_CLUSTERS % (C / C_PER_CLUSTER) == 0;
    if constexpr (!CONSTANT_C_LOOP) {
      static_assert(wgrad_sync_method != WGRAD_ARRIVE_AND_WAIT_GROUP && wgrad_sync_method != WGRAD_REUSE_SUM_SYNC_GROUP,
                    "grid sync is required when each block is responsible for multiple channel ranges");
    }
    NCScheduler<false, C, C_PER_CLUSTER, NUM_VIRTUAL_CLUSTERS, PERSISTENT> nc_scheduler(
        n);  // TODO: I don't know why the template specialization with CONSTANT_C_LOOP=true is slower.

    [[maybe_unused]] int virtual_cluster_idx_c = blockIdx.y % (C / C_PER_CLUSTER);
    [[maybe_unused]] cg::grid_group::arrival_token wgrad_sync_token;
    [[maybe_unused]] float dw_thread[VEC_ELEMS];
    [[maybe_unused]] float db_thread[VEC_ELEMS];
    [[maybe_unused]] __shared__ union {
      float2 dwdb_block_buffer[BLOCK_DIM_X][VEC_ELEMS];
      struct {
        float wgrad_buffer[BLOCK_DIM_X / 32][32];
        float bgrad_buffer[BLOCK_DIM_X / 32][32];
      } transpose_buffer;
    } union_smem;
    if constexpr (REQUIRES_WGRAD && CONSTANT_C_LOOP) {
      for (int i = 0; i < VEC_ELEMS; i++) {
        dw_thread[i] = 0.f;
        db_thread[i] = 0.f;
      }
    }
    float *red_buffer_wgrad =
        &red_buffer[(2 * NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK) * 2];
    unsigned *barrier_wgrad = barrier + NUM_VIRTUAL_CLUSTERS;
    if constexpr (REQUIRES_WGRAD && wgrad_sync_method != WGRAD_SYNC_AT_LAST) {
      if (nc_scheduler.at_end(n)) {
        static_assert(PERSISTENT, "persistent is a must for reducing wgrad");
        if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GRID) {
          wgrad_sync_token = group_barrier_arrive<NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE, PERSISTENT>(
              barrier_wgrad, blockIdx.x + blockIdx.y == 0);
        } else if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP) {
          wgrad_sync_token =
              group_barrier_arrive<NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE / (C / C_PER_CLUSTER), PERSISTENT>(
                  barrier_wgrad + virtual_cluster_idx_c, blockIdx.x + blockIdx.y / (C / C_PER_CLUSTER) == 0);
        } else if constexpr (wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GRID) {
          wgrad_sync_token = group_barrier_arrive<NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE, PERSISTENT>(
              barrier_wgrad, blockIdx.x + blockIdx.y == 0);
          group_barrier_wait(barrier_wgrad, wgrad_sync_token);
        } else if constexpr (wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GROUP) {
          wgrad_sync_token =
              group_barrier_arrive<NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE / (C / C_PER_CLUSTER), PERSISTENT>(
                  barrier_wgrad + virtual_cluster_idx_c, blockIdx.x + blockIdx.y / (C / C_PER_CLUSTER) == 0);
          group_barrier_wait(barrier_wgrad + virtual_cluster_idx_c, wgrad_sync_token);
        }
      }
    }

    while (true) {  // TODO: unroll the loop
      if constexpr (PERSISTENT) {
        if (nc_scheduler.at_end(n)) {
          break;
        }
      }
      auto [n_loop, c_loop] = nc_scheduler.get_nc();
      if constexpr (PERSISTENT) {
        nc_scheduler.next(n);
      }
      static_assert(C_PER_BLOCK % VEC_ELEMS == 0, "cannot vectorize");
      static_assert((BLOCK_DIM_X * VEC_ELEMS) % C_PER_BLOCK == 0,
                    "each block should load one or more C_PER_BLOCK at once");
      constexpr int ROWS_PER_IO = BLOCK_DIM_X * VEC_ELEMS / C_PER_BLOCK;
      static_assert(ROWS_PER_BLOCK % ROWS_PER_IO == 0, "cannot determine the IO times per batch");
      int block_channel_start = virtual_block_idx_x * C_PER_BLOCK + c_loop * C_PER_CLUSTER;
      int block_group_start = block_channel_start / CPG;
      int thread_channel_start = block_channel_start + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS) * VEC_ELEMS;
      U frag_x[ROWS_PER_BLOCK / ROWS_PER_IO];
      U frag_dy[ROWS_PER_BLOCK / ROWS_PER_IO];

      constexpr int GCD_VEC_CPG = gcd(VEC_ELEMS, CPG);

      constexpr bool SINGLE_GROUP_PER_BLOCK = CPG % C_PER_BLOCK == 0;
      [[maybe_unused]] __shared__ float2 sum_per_channel_multi_group[C_PER_BLOCK / GCD_VEC_CPG][relative_prime(
          128 / (int)sizeof(float2), ROWS_PER_IO)];
      [[maybe_unused]] __shared__ float2 sum_per_channel_single_group[BLOCK_DIM_X / 32];

      float frag_mean[VEC_ELEMS / GCD_VEC_CPG];
      float frag_var[VEC_ELEMS / GCD_VEC_CPG];
      for (int k = 0; k < VEC_ELEMS; k += GCD_VEC_CPG) {
        float2 value =
            *reinterpret_cast<float2 const *>(&mean_var[(n_loop * G + (thread_channel_start + k) / CPG) * 2]);
        frag_mean[k / GCD_VEC_CPG] = value.x;
        frag_var[k / GCD_VEC_CPG] = value.y;
      }

      U uw = *reinterpret_cast<U const *>(&w[thread_channel_start]);
      U ub;
      if constexpr (SILU) {
        ub = *reinterpret_cast<U const *>(&b[thread_channel_start]);
      }
      if constexpr (REQUIRES_WGRAD && !CONSTANT_C_LOOP) {
        for (int i = 0; i < VEC_ELEMS; i++) {
          dw_thread[i] = 0.f;
          db_thread[i] = 0.f;
        }
      }

      if constexpr (LOAD_TWICE) {
        float2 frag_sum_per_channel[VEC_ELEMS / GCD_VEC_CPG]{};
        for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) {
          int64_t input_idx =
              n_loop * HW * C +
              (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C +
              thread_channel_start;
          U ux = *reinterpret_cast<U const *>(&x[input_idx]);
          U udy = *reinterpret_cast<U const *>(&grad_output[input_idx]);
          for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) {
            float2 sum = frag_sum_per_channel[i];
            for (int k = 0; k < GCD_VEC_CPG; k++) {
              float rnorm = rsqrtf(frag_var[i] + eps);
              float x_norm =
                  ((float)ux.data[i * GCD_VEC_CPG + k] - frag_mean[i]) * rnorm;  // TODO: store rsqrtf in mean_var
              float grad_gn = udy.data[i * GCD_VEC_CPG + k];
              if constexpr (SILU) {
                float x_gn = x_norm * (float)uw.data[i * GCD_VEC_CPG + k] + (float)ub.data[i * GCD_VEC_CPG + k];
                float s = 1.f / (1.f + expf(-x_gn));
                grad_gn *= s * (1.f + x_gn * (1.f - s));
              }
              sum.x += grad_gn * (float)uw.data[i * GCD_VEC_CPG + k];
              sum.y += x_norm * (grad_gn * (float)uw.data[i * GCD_VEC_CPG + k]);
              if constexpr (REQUIRES_WGRAD) {
                dw_thread[i * GCD_VEC_CPG + k] += x_norm * grad_gn;
                db_thread[i * GCD_VEC_CPG + k] += grad_gn;
              }
            }
            frag_sum_per_channel[i] = sum;
          }
        }
        for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) {
          if constexpr (SINGLE_GROUP_PER_BLOCK) {
            for (int mask = 16; mask > 0; mask >>= 1) {
              frag_sum_per_channel[i].x += __shfl_xor_sync(FINAL_MASK, frag_sum_per_channel[i].x, mask, 32);
              frag_sum_per_channel[i].y += __shfl_xor_sync(FINAL_MASK, frag_sum_per_channel[i].y, mask, 32);
            }
            static_assert(VEC_ELEMS / GCD_VEC_CPG == 1, "process only one element for each warp");
            if (threadIdx.x % 32 == 0) {
              sum_per_channel_single_group[threadIdx.x / 32] = frag_sum_per_channel[i];
            }
          } else {
            sum_per_channel_multi_group[i * (C_PER_BLOCK / VEC_ELEMS) + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS)]
                                       [threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)] = frag_sum_per_channel[i];
          }
        }
        __syncthreads();
      } else {
        for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) {
          int64_t input_idx =
              n_loop * HW * C +
              (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C +
              thread_channel_start;
          frag_x[j] = *reinterpret_cast<U const *>(&x[input_idx]);
          frag_dy[j] = *reinterpret_cast<U const *>(&grad_output[input_idx]);
        }

        for (int i = 0; i < VEC_ELEMS / GCD_VEC_CPG; i++) {
          float2 sum = {0.f, 0.f};
          for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) {
            for (int k = 0; k < GCD_VEC_CPG; k++) {
              float rnorm = rsqrtf(frag_var[i] + eps);
              float x_norm = ((float)frag_x[j].data[i * GCD_VEC_CPG + k] - frag_mean[i]) *
                             rnorm;  // TODO: store rsqrtf in mean_var
              float grad_gn = frag_dy[j].data[i * GCD_VEC_CPG + k];
              if constexpr (SILU) {
                float x_gn = x_norm * (float)uw.data[i * GCD_VEC_CPG + k] + (float)ub.data[i * GCD_VEC_CPG + k];
                float s = 1.f / (1.f + expf(-x_gn));
                grad_gn *= s * (1.f + x_gn * (1.f - s));
              }
              sum.x += grad_gn * (float)uw.data[i * GCD_VEC_CPG + k];
              sum.y += x_norm * (grad_gn * (float)uw.data[i * GCD_VEC_CPG + k]);
              if constexpr (REQUIRES_WGRAD) {
                dw_thread[i * GCD_VEC_CPG + k] += x_norm * grad_gn;
                db_thread[i * GCD_VEC_CPG + k] += grad_gn;
              }
            }
          }
          if constexpr (SINGLE_GROUP_PER_BLOCK) {
            for (int mask = 16; mask > 0; mask >>= 1) {
              sum.x += __shfl_xor_sync(FINAL_MASK, sum.x, mask, 32);
              sum.y += __shfl_xor_sync(FINAL_MASK, sum.y, mask, 32);
            }
            static_assert(VEC_ELEMS / GCD_VEC_CPG == 1, "process only one element for each warp");
            if (threadIdx.x % 32 == 0) {
              sum_per_channel_single_group[threadIdx.x / 32] = sum;
            }
          } else {
            sum_per_channel_multi_group[i * (C_PER_BLOCK / VEC_ELEMS) + threadIdx.x % (C_PER_BLOCK / VEC_ELEMS)]
                                       [threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)] = sum;
          }
        }
        __syncthreads();
      }

      if ((CONSTANT_C_LOOP && nc_scheduler.at_end(n)) || !CONSTANT_C_LOOP) {
        constexpr int NT_C = max_divisor(C_PER_BLOCK, BLOCK_DIM_X);  // Number of threads on the C axis
        constexpr int NT_R =
            1;  // std::min(32, (int)round_down_pow2(BLOCK_DIM_X / NT_C));  // Number of threads on the ROWS axis
        // TODO: swizzle for NT_R
        for (int i = 0; i < VEC_ELEMS; i++) {
          union_smem.dwdb_block_buffer[threadIdx.x][i ^ ((threadIdx.x / (16 / VEC_ELEMS)) & (VEC_ELEMS - 1))] =
              float2{dw_thread[i], db_thread[i]};
        }
        __syncthreads();
        static_assert(NT_C * NT_R <= BLOCK_DIM_X, "not enough threads");
        static_assert(C_PER_BLOCK % NT_C == 0, "need to loop once more and check c < C_PER_BLOCK");
        for (int i = 0; i < C_PER_BLOCK / NT_C; i++) {
          int c = i * NT_C + threadIdx.x / NT_R;
          float dw_block = 0.f;
          float db_block = 0.f;
          if (BLOCK_DIM_X == NT_C * NT_R || threadIdx.x < NT_C * NT_R) {
            for (int j = threadIdx.x % NT_R; j < ROWS_PER_IO; j += NT_R) {
              int src_thread = j * (C_PER_BLOCK / VEC_ELEMS) + c / VEC_ELEMS;
              float2 val = union_smem.dwdb_block_buffer[src_thread][(c % VEC_ELEMS) ^ ((src_thread / (16 / VEC_ELEMS)) &
                                                                                       (VEC_ELEMS - 1))];
              dw_block += val.x;
              db_block += val.y;
            }
          }
          static_assert(32 % NT_R == 0, "cannot shuffle");
          for (int mask = NT_R / 2; mask > 0; mask >>= 1) {
            dw_block += __shfl_xor_sync(FINAL_MASK, dw_block, mask, 32);
            db_block += __shfl_xor_sync(FINAL_MASK, db_block, mask, 32);
          }
          if (BLOCK_DIM_X == NT_C * NT_R || threadIdx.x < NT_C * NT_R) {
            if (threadIdx.x % NT_R == 0) {
              if constexpr (CONSTANT_C_LOOP) {
                *reinterpret_cast<float2 *>(
                    &red_buffer_wgrad
                        [((blockIdx.y / (C / C_PER_CLUSTER) * virtual_cluster_dim_y + virtual_block_idx_y) * C +
                          c_loop * C_PER_CLUSTER + virtual_block_idx_x * C_PER_BLOCK + c) *
                         2]) = float2{dw_block, db_block};
              } else {
                *reinterpret_cast<float2 *>(
                    &red_buffer_wgrad[((n_loop * virtual_cluster_dim_y + virtual_block_idx_y) * C +
                                       c_loop * C_PER_CLUSTER + virtual_block_idx_x * C_PER_BLOCK + c) *
                                      2]) = float2{dw_block, db_block};
              }
            }
          }
        }
      }

      constexpr bool USE_SHARED_RED_BUFFER = HARDWARE_CLUSTER || VIRTUAL_CLUSTER_SIZE == 1;
      constexpr bool STORE_MEAN_VAR_IN_SHARED_RED_BUFFER =
          VIRTUAL_CLUSTER_SIZE == 1 &&
          MAX_NUM_GROUPS_PER_BLOCK == 1;  // MAX_NUM_GROUPS_PER_BLOCK > 1 is possible but not implemented
      [[maybe_unused]] __align__(16)
          __shared__ float2 shared_red_buffer[MAX_NUM_GROUPS_PER_BLOCK * (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER ? 1 : 2)];

      // Block sum
      if constexpr (SINGLE_GROUP_PER_BLOCK) {
        // block reduce
        if (threadIdx.x < 32) {
          float2 sum_local_group =
              threadIdx.x < BLOCK_DIM_X / 32 ? sum_per_channel_single_group[threadIdx.x] : float2{0.f, 0.f};
          constexpr int warp_num_pow2 = round_up_pow2(BLOCK_DIM_X / 32);
          for (int mask = warp_num_pow2 / 2; mask > 0; mask >>= 1) {
            sum_local_group.x += __shfl_xor_sync(FINAL_MASK, sum_local_group.x, mask, 32);
            sum_local_group.y += __shfl_xor_sync(FINAL_MASK, sum_local_group.y, mask, 32);
          }
          if (threadIdx.x == 0) {
            if constexpr (USE_SHARED_RED_BUFFER) {
              if constexpr (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) {
                shared_red_buffer[0] = compute_mean_var(sum_local_group);
              } else {
                shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + 0] = sum_local_group;
              }
            } else {
              *reinterpret_cast<float2 *>(
                  &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK +
                               virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK +
                               // (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y +
                               virtual_block_idx_y) *
                              2]) = sum_local_group;
            }
          }
        }
      } else {
        // The number of threads to calculate the sum of each group (should be a power of 2 for warp reduce)
        constexpr int THREADS_PER_GROUP = std::min(std::min(32U, round_up_pow2(ROWS_PER_IO)),
                                                   round_up_pow2(BLOCK_DIM_X / MAX_NUM_GROUPS_PER_BLOCK / 2 + 1));
        static_assert(BLOCK_DIM_X >= MAX_NUM_GROUPS_PER_BLOCK * THREADS_PER_GROUP, "not enough threads");
        float2 sum_local_group = {0.f, 0.f};
        if (threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) {
          int local_group_idx = block_group_start + threadIdx.x / THREADS_PER_GROUP;
          // TODO: map threads to both the CPG loop and the ROWS loop
          for (int local_c_loop = 0; local_c_loop < CPG; local_c_loop += GCD_VEC_CPG) {
            int c = local_group_idx * CPG + local_c_loop;
            if (C_PER_BLOCK % CPG == 0 || (c >= block_channel_start && c < block_channel_start + C_PER_BLOCK)) {
              for (int src_thread_tile_y = threadIdx.x % THREADS_PER_GROUP; src_thread_tile_y < ROWS_PER_IO;
                   src_thread_tile_y += THREADS_PER_GROUP) {
                int channel_idx = (c - block_channel_start) / GCD_VEC_CPG;
                channel_idx = channel_idx % (VEC_ELEMS / GCD_VEC_CPG) * (C_PER_BLOCK / VEC_ELEMS) +
                              channel_idx / (VEC_ELEMS / GCD_VEC_CPG);
                sum_local_group.x += sum_per_channel_multi_group[channel_idx][src_thread_tile_y].x;
                sum_local_group.y += sum_per_channel_multi_group[channel_idx][src_thread_tile_y].y;
              }
            }
          }
        }
        static_assert(32 % THREADS_PER_GROUP == 0, "cannot shuffle");
        for (int mask = THREADS_PER_GROUP / 2; mask > 0; mask >>= 1) {
          sum_local_group.x += __shfl_xor_sync(FINAL_MASK, sum_local_group.x, mask, 32);
          sum_local_group.y += __shfl_xor_sync(FINAL_MASK, sum_local_group.y, mask, 32);
        }
        if (threadIdx.x % THREADS_PER_GROUP == 0 && threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) {
          if constexpr (USE_SHARED_RED_BUFFER) {
            static_assert(HARDWARE_CLUSTER || VIRTUAL_CLUSTER_SIZE == 1, "no distributed shared memory");
            if constexpr (STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) {
              shared_red_buffer[threadIdx.x / THREADS_PER_GROUP] = compute_mean_var(sum_local_group);
            } else {
              shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP] = sum_local_group;
            }
          } else {
            *reinterpret_cast<float2 *>(
                &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK +
                             virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK +
                             (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + virtual_block_idx_y) *
                            2]) = sum_local_group;
          }
        }
      }

      if constexpr (REQUIRES_WGRAD && wgrad_sync_method != WGRAD_SYNC_AT_LAST) {
        if (nc_scheduler.at_end(n)) {
          static_assert(PERSISTENT, "persistent is a must for reducing wgrad");
          if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GRID) {
            virtual_cluster_sync<VIRTUAL_CLUSTER_SIZE, PERSISTENT, HARDWARE_CLUSTER>(barrier);
            wgrad_sync_token = group_barrier_arrive<NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE, PERSISTENT>(
                barrier_wgrad, blockIdx.x + blockIdx.y == 0);
          } else if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP) {
            virtual_cluster_sync<VIRTUAL_CLUSTER_SIZE, PERSISTENT, HARDWARE_CLUSTER>(barrier);
            wgrad_sync_token =
                group_barrier_arrive<NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE / (C / C_PER_CLUSTER), PERSISTENT>(
                    barrier_wgrad + virtual_cluster_idx_c, blockIdx.x + blockIdx.y / (C / C_PER_CLUSTER) == 0);
          } else if constexpr (wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GRID) {
            static_assert(!HARDWARE_CLUSTER,
                          "Distributed smem sync cannot reuse gmem sync. Use WGRAD_ARRIVE_AND_WAIT_GRID instead.");
            wgrad_sync_token = group_barrier_arrive<NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE, PERSISTENT>(
                barrier_wgrad, blockIdx.x + blockIdx.y == 0);
            group_barrier_wait(barrier_wgrad, wgrad_sync_token);
          } else if constexpr (wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GROUP) {
            static_assert(!HARDWARE_CLUSTER,
                          "Distributed smem sync cannot reuse gmem sync. Use WGRAD_ARRIVE_AND_WAIT_GROUP instead.");
            wgrad_sync_token =
                group_barrier_arrive<NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE / (C / C_PER_CLUSTER), PERSISTENT>(
                    barrier_wgrad + virtual_cluster_idx_c, blockIdx.x + blockIdx.y / (C / C_PER_CLUSTER) == 0);
            group_barrier_wait(barrier_wgrad + virtual_cluster_idx_c, wgrad_sync_token);
          }
        } else {
          virtual_cluster_sync<VIRTUAL_CLUSTER_SIZE, PERSISTENT, HARDWARE_CLUSTER>(barrier);
        }
      } else {
        virtual_cluster_sync<VIRTUAL_CLUSTER_SIZE, PERSISTENT, HARDWARE_CLUSTER>(barrier);
      }

      // Group sum
      __shared__ float2 mean_var[MAX_NUM_GROUPS_PER_BLOCK];
      if constexpr (!STORE_MEAN_VAR_IN_SHARED_RED_BUFFER) {
        // The number of threads to calculate the sum of each group (should be a power of 2 for warp reduce)
        constexpr int THREADS_PER_GROUP = std::min(std::min(32U, round_up_pow2(virtual_cluster_dim_y)),
                                                   round_up_pow2(BLOCK_DIM_X / MAX_NUM_GROUPS_PER_BLOCK / 2 + 1));
        static_assert(BLOCK_DIM_X >= MAX_NUM_GROUPS_PER_BLOCK * THREADS_PER_GROUP, "not enough threads");
        float2 sum_global_group = {0.f, 0.f};
        if (threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) {
          if constexpr (C_PER_BLOCK % CPG == 0) {
            // Special case: no cross-virtual_cluster_dim_x reduction
            float2 buffer[up_div(virtual_cluster_dim_y, THREADS_PER_GROUP)];
            for (int i = threadIdx.x % THREADS_PER_GROUP; i < virtual_cluster_dim_y; i += THREADS_PER_GROUP) {
              float2 val;
              if constexpr (USE_SHARED_RED_BUFFER) {
                if constexpr (VIRTUAL_CLUSTER_SIZE == 1) {
                  val = shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP];
                } else {
                  static_assert(HARDWARE_CLUSTER, "no distributed shared memory");
                  float2 const *src_shared_red_buffer = cg::this_cluster().map_shared_rank(
                      shared_red_buffer, i * virtual_cluster_dim_x + virtual_block_idx_x);
                  val = src_shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + threadIdx.x / THREADS_PER_GROUP];
                }
              } else {
                val = *reinterpret_cast<float2 const *>(
                    &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK +
                                 virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK +
                                 (threadIdx.x / THREADS_PER_GROUP) * virtual_cluster_dim_y + i) *
                                2]);
              }
              buffer[i / THREADS_PER_GROUP] = val;
            }
            for (int i = threadIdx.x % THREADS_PER_GROUP; i < virtual_cluster_dim_y; i += THREADS_PER_GROUP) {
              float2 val = buffer[i / THREADS_PER_GROUP];
              sum_global_group.x += val.x;
              sum_global_group.y += val.y;
            }
          } else {
            // Common case: cross-virtual_cluster_dim_x reduction
            int local_group_idx = block_group_start + threadIdx.x / THREADS_PER_GROUP;
            for (int i = threadIdx.x % THREADS_PER_GROUP; i < VIRTUAL_CLUSTER_SIZE; i += THREADS_PER_GROUP) {
              int src_virtual_block_idx_x = i % virtual_cluster_dim_x;
              int src_block_channel_start = src_virtual_block_idx_x * C_PER_BLOCK + c_loop * C_PER_CLUSTER;
              int src_block_group_start = src_block_channel_start / CPG;
              int relative_group_idx = local_group_idx - src_block_group_start;
              if (0 <= relative_group_idx && relative_group_idx < MAX_NUM_GROUPS_PER_BLOCK) {
                float2 val;
                if constexpr (USE_SHARED_RED_BUFFER) {
                  static_assert(HARDWARE_CLUSTER, "no distributed shared memory");
                  static_assert(VIRTUAL_CLUSTER_SIZE != 1,
                                "layout error: should not add (step * MAX_NUM_GROUPS_PER_BLOCK)");
                  float2 const *src_shared_red_buffer = cg::this_cluster().map_shared_rank(shared_red_buffer, i);
                  val = src_shared_red_buffer[step * MAX_NUM_GROUPS_PER_BLOCK + relative_group_idx];
                } else {
                  val = *reinterpret_cast<float2 const *>(
                      &red_buffer[((step * gridDim.y + blockIdx.y) * VIRTUAL_CLUSTER_SIZE * MAX_NUM_GROUPS_PER_BLOCK +
                                   src_virtual_block_idx_x * virtual_cluster_dim_y * MAX_NUM_GROUPS_PER_BLOCK +
                                   relative_group_idx * virtual_cluster_dim_y + i / virtual_cluster_dim_x) *
                                  2]);
                }
                sum_global_group.x += val.x;
                sum_global_group.y += val.y;
              }
            }
          }
        }
        if constexpr (USE_SHARED_RED_BUFFER && VIRTUAL_CLUSTER_SIZE > 1) {
          // Need cluster sync after distributed shared memory access, otherwise behavior is undefined
          if constexpr (PERSISTENT) {
            if (nc_scheduler.at_end(n)) {
              cg::this_cluster().barrier_arrive();
            }
          } else {
            cg::this_cluster().barrier_arrive();
          }
        }
        static_assert(32 % THREADS_PER_GROUP == 0, "cannot shuffle");
        for (int mask = THREADS_PER_GROUP / 2; mask > 0; mask >>= 1) {
          sum_global_group.x += __shfl_xor_sync(FINAL_MASK, sum_global_group.x, mask, 32);
          sum_global_group.y += __shfl_xor_sync(FINAL_MASK, sum_global_group.y, mask, 32);
        }
        if (threadIdx.x % THREADS_PER_GROUP == 0 && threadIdx.x / THREADS_PER_GROUP < MAX_NUM_GROUPS_PER_BLOCK) {
          mean_var[threadIdx.x / THREADS_PER_GROUP] = compute_mean_var(sum_global_group);
        }
        __syncthreads();
      }

      auto get_mean_var = [&](int relative_group_idx) {
        return STORE_MEAN_VAR_IN_SHARED_RED_BUFFER ? shared_red_buffer[relative_group_idx]
                                                   : mean_var[relative_group_idx];
      };

      float frag_dyw[VEC_ELEMS / GCD_VEC_CPG];
      float frag_xdyw[VEC_ELEMS / GCD_VEC_CPG];
      for (int k = 0; k < VEC_ELEMS; k += GCD_VEC_CPG) {
        frag_dyw[k / GCD_VEC_CPG] = get_mean_var((thread_channel_start + k) / CPG - block_group_start).x;
        frag_xdyw[k / GCD_VEC_CPG] = get_mean_var((thread_channel_start + k) / CPG - block_group_start).y;
      }

      for (int j = 0; j < ROWS_PER_BLOCK / ROWS_PER_IO; j++) {
        int64_t input_idx =
            n_loop * HW * C +
            (virtual_block_idx_y * ROWS_PER_BLOCK + j * ROWS_PER_IO + threadIdx.x / (C_PER_BLOCK / VEC_ELEMS)) * C +
            thread_channel_start;
        U ux;
        U udy;
        if constexpr (LOAD_TWICE) {
          ux = *reinterpret_cast<U const *>(&x[input_idx]);
          udy = *reinterpret_cast<U const *>(&grad_output[input_idx]);
        } else {
          ux = frag_x[j];
          udy = frag_dy[j];
        }
        U val;
        for (int k = 0; k < VEC_ELEMS; k++) {
          float rnorm = rsqrtf(frag_var[k / GCD_VEC_CPG] + eps);
          float x_norm = ((float)ux.data[k] - frag_mean[k / GCD_VEC_CPG]) * rnorm;  // TODO: store rsqrtf in mean_var
          float grad_gn = udy.data[k];
          if constexpr (SILU) {
            float x_gn = x_norm * (float)uw.data[k] + (float)ub.data[k];
            float s = 1.f / (1.f + expf(-x_gn));
            grad_gn *= s * (1.f + x_gn * (1.f - s));
          }
          val.data[k] =
              (grad_gn * (float)uw.data[k] - frag_dyw[k / GCD_VEC_CPG] - frag_xdyw[k / GCD_VEC_CPG] * x_norm) * rnorm;
        }
        *reinterpret_cast<U *>(&grad_input[input_idx]) = val;
      }

      if constexpr (!STORE_MEAN_VAR_IN_SHARED_RED_BUFFER && USE_SHARED_RED_BUFFER && VIRTUAL_CLUSTER_SIZE > 1) {
        if constexpr (PERSISTENT) {
          if (nc_scheduler.at_end(n)) {
            cg::this_cluster().barrier_wait();
          }
        } else {
          cg::this_cluster().barrier_wait();
        }
      }

      if constexpr (!PERSISTENT) {
        break;
      }
      step ^= 1;
    }

    // Wgrad sum
    if constexpr (REQUIRES_WGRAD) {
      static_assert(PERSISTENT, "cannot reduce wgrad");
      static_assert(C % 32 == 0, "cannot reduce wgrad");
      if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GRID) {
        group_barrier_wait(barrier_wgrad, wgrad_sync_token);
      } else if constexpr (wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP) {
        group_barrier_wait(barrier_wgrad + virtual_cluster_idx_c, wgrad_sync_token);
      } else if constexpr (wgrad_sync_method == WGRAD_SYNC_AT_LAST) {
        cg::this_grid().sync();
      }

      // If group sync, map blocks that are responsible for the same range of channels to these channels (named "split
      // channels"); otherwise, map all blocks to all channels.
      constexpr bool split_channels =
          wgrad_sync_method == WGRAD_ARRIVE_AND_WAIT_GROUP || wgrad_sync_method == WGRAD_REUSE_SUM_SYNC_GROUP;

      for (int c = split_channels ? virtual_cluster_idx_c * C_PER_CLUSTER +
                                        32 * (blockIdx.y / (C / C_PER_CLUSTER) * VIRTUAL_CLUSTER_SIZE + blockIdx.x)
                                  : 32 * (blockIdx.y * VIRTUAL_CLUSTER_SIZE + blockIdx.x);
           split_channels ? c < (virtual_cluster_idx_c + 1) * C_PER_CLUSTER : c < C;
           c += split_channels ? 32 * (NUM_VIRTUAL_CLUSTERS / (C / C_PER_CLUSTER) * VIRTUAL_CLUSTER_SIZE)
                               : 32 * (NUM_VIRTUAL_CLUSTERS * VIRTUAL_CLUSTER_SIZE)) {
        int64_t rows = (CONSTANT_C_LOOP ? std::min(n, (int64_t)NUM_VIRTUAL_CLUSTERS / (C / C_PER_CLUSTER)) : n) *
                       virtual_cluster_dim_y;
        float sum_wgrad = 0.f;
        float sum_bgrad = 0.f;
        if ((split_channels &&
             (C_PER_CLUSTER % 32 == 0 || c + threadIdx.x % 32 < (virtual_cluster_idx_c + 1) * C_PER_CLUSTER)) ||
            (!split_channels && (C % 32 == 0 || c + threadIdx.x % 32 < C))) {
          for (int64_t i = threadIdx.x / 32; i < rows; i += BLOCK_DIM_X / 32) {
            float2 val = *reinterpret_cast<float2 const *>(&red_buffer_wgrad[(i * C + c + threadIdx.x % 32) * 2]);
            sum_wgrad += val.x;
            sum_bgrad += val.y;
          }
        }
        constexpr int warp_num_pow2 = round_up_pow2(BLOCK_DIM_X / 32);
        union_smem.transpose_buffer
            .wgrad_buffer[threadIdx.x / 32][(threadIdx.x % 32) ^ ((threadIdx.x / 32) * (32 / warp_num_pow2))] =
            sum_wgrad;
        union_smem.transpose_buffer
            .bgrad_buffer[threadIdx.x / 32][(threadIdx.x % 32) ^ ((threadIdx.x / 32) * (32 / warp_num_pow2))] =
            sum_bgrad;
        __syncthreads();
        for (int i = threadIdx.x / warp_num_pow2;
             i < 32 &&
             ((split_channels && (C_PER_CLUSTER % 32 == 0 || c + i < (virtual_cluster_idx_c + 1) * C_PER_CLUSTER)) ||
              (!split_channels && (C % 32 == 0 || c + i < C)));
             i += BLOCK_DIM_X / warp_num_pow2) {
          int j = threadIdx.x % warp_num_pow2;
          float sum_wgrad =
              j < BLOCK_DIM_X / 32 ? union_smem.transpose_buffer.wgrad_buffer[j][i ^ (j * (32 / warp_num_pow2))] : 0.f;
          float sum_bgrad =
              j < BLOCK_DIM_X / 32 ? union_smem.transpose_buffer.bgrad_buffer[j][i ^ (j * (32 / warp_num_pow2))] : 0.f;
          for (int mask = warp_num_pow2 / 2; mask > 0; mask >>= 1) {
            sum_wgrad += __shfl_xor_sync((uint64_t(1) << warp_num_pow2) - 1, sum_wgrad, mask, warp_num_pow2);
            sum_bgrad += __shfl_xor_sync((uint64_t(1) << warp_num_pow2) - 1, sum_bgrad, mask, warp_num_pow2);
          }
          if (j == 0) {
            grad_weight[c + i] = sum_wgrad;
            grad_bias[c + i] = sum_bgrad;
          }
        }
        __syncthreads();
      }
    }
  }
}

}  // namespace group_norm_v2
